{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The autoreload extension is already loaded. To reload it, use:\n",
      "  %reload_ext autoreload\n"
     ]
    }
   ],
   "source": [
    "%load_ext autoreload"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "from PIL import Image\n",
    "import numpy as np\n",
    "import torch\n",
    "from torchvision import transforms"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<module 'pytorch_pretrained_vit' from '/home/luke/projects/experiments/ViT-PyTorch/pytorch_pretrained_vit/__init__.py'>"
      ]
     },
     "execution_count": 24,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from importlib import reload\n",
    "import pytorch_pretrained_vit\n",
    "reload(pytorch_pretrained_vit)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [],
   "source": [
    "name = 'L_32'\n",
    "filename = 'jax_weights/ViT-L_32.npz'\n",
    "# npz_files = {\n",
    "#     'B_16': 'jax_weights/ViT-B_16.npz',\n",
    "#     'B_32': 'jax_weights/ViT-B_32.npz',\n",
    "#     'L_16': 'jax_weights/ViT-L_16.npz',\n",
    "#     'L_32': 'jax_weights/ViT-L_32.npz',\n",
    "#     'B_16_imagenet1k': 'jax_weights/ViT-B_16_imagenet1k.npz',\n",
    "#     'B_32_imagenet1k': 'jax_weights/ViT-B_32_imagenet1k.npz',\n",
    "#     'L_16_imagenet1k': 'jax_weights/ViT-L_16_imagenet1k.npz',\n",
    "#     'L_32_imagenet1k': 'jax_weights/ViT-L_32_imagenet1k.npz',\n",
    "# }\n",
    "num_classes = 21843\n",
    "# num_classes = 1000"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [],
   "source": [
    "def jax_to_pytorch(k):\n",
    "    k = k.replace('Transformer/encoder_norm', 'norm')\n",
    "    k = k.replace('LayerNorm_0', 'norm1')\n",
    "    k = k.replace('LayerNorm_2', 'norm2')\n",
    "    k = k.replace('MlpBlock_3/Dense_0', 'pwff.fc1')\n",
    "    k = k.replace('MlpBlock_3/Dense_1', 'pwff.fc2')\n",
    "    k = k.replace('MultiHeadDotProductAttention_1/out', 'proj')\n",
    "    k = k.replace('MultiHeadDotProductAttention_1/query', 'attn.proj_q')\n",
    "    k = k.replace('MultiHeadDotProductAttention_1/key', 'attn.proj_k')\n",
    "    k = k.replace('MultiHeadDotProductAttention_1/value', 'attn.proj_v')\n",
    "    k = k.replace('Transformer/posembed_input', 'positional_embedding')\n",
    "    k = k.replace('encoderblock_', 'blocks.')\n",
    "    k = 'patch_embedding.bias' if k == 'embedding/bias' else k\n",
    "    k = 'patch_embedding.weight' if k == 'embedding/kernel' else k\n",
    "    k = 'class_token' if k == 'cls' else k\n",
    "    k = k.replace('head', 'fc')\n",
    "    k = k.replace('kernel', 'weight')\n",
    "    k = k.replace('scale', 'weight')\n",
    "    k = k.replace('/', '.')\n",
    "    k = k.lower()\n",
    "    return k\n",
    "\n",
    "\n",
    "def convert(npz, state_dict):\n",
    "    new_state_dict = {}\n",
    "    pytorch_k2v = {jax_to_pytorch(k): v for k, v in npz.items()}\n",
    "    for pytorch_k, pytorch_v in state_dict.items():\n",
    "        \n",
    "        # Naming\n",
    "        if 'self_attn.out_proj.weight' in pytorch_k:\n",
    "            v = pytorch_k2v[pytorch_k]\n",
    "            v = v.reshape(v.shape[0] * v.shape[1], v.shape[2])\n",
    "        elif 'self_attn.in_proj_' in pytorch_k:\n",
    "            v = np.stack((pytorch_k2v[pytorch_k + '*q'], \n",
    "                          pytorch_k2v[pytorch_k + '*k'], \n",
    "                          pytorch_k2v[pytorch_k + '*v']), axis=0)\n",
    "        else:\n",
    "            if pytorch_k not in pytorch_k2v:\n",
    "                print(pytorch_k, list(pytorch_k2v.keys()))\n",
    "                assert False\n",
    "            v = pytorch_k2v[pytorch_k]\n",
    "        v = torch.from_numpy(v)\n",
    "        \n",
    "        # Sizing\n",
    "        if '.weight' in pytorch_k:\n",
    "            if len(pytorch_v.shape) == 2:\n",
    "                v = v.transpose(0, 1)\n",
    "            if len(pytorch_v.shape) == 4:\n",
    "                v = v.permute(3, 2, 0, 1)\n",
    "        if ('proj.weight' in pytorch_k):\n",
    "            v = v.transpose(0, 1)\n",
    "            v = v.reshape(-1, v.shape[-1]).T\n",
    "        if ('attn.proj_' in pytorch_k and 'weight' in pytorch_k):\n",
    "            v = v.permute(0, 2, 1)\n",
    "            v = v.reshape(-1, v.shape[-1])\n",
    "        if 'attn.proj_' in pytorch_k and 'bias' in pytorch_k:\n",
    "            v = v.reshape(-1)\n",
    "        new_state_dict[pytorch_k] = v\n",
    "    return new_state_dict\n",
    "\n",
    "\n",
    "def check_model(model, name):\n",
    "    model.eval()\n",
    "    img = Image.open('../examples/simple/img.jpg')\n",
    "    img = transforms.Compose([transforms.Resize(model.image_size), transforms.ToTensor(), transforms.Normalize(0.5, 0.5)])(img).unsqueeze(0)\n",
    "    if 'imagenet1k' in name:\n",
    "        labels_file = '../examples/simple/labels_map.txt' \n",
    "        labels_map = json.load(open(labels_file))\n",
    "        labels_map = [labels_map[str(i)] for i in range(1000)]\n",
    "        print('-----\\nShould be index 388 (panda) w/ high probability:')\n",
    "    else:\n",
    "        labels_map = open('../examples/simple/labels_map_21k.txt').read().splitlines()\n",
    "    with torch.no_grad():\n",
    "        outputs = model(img).squeeze(0)\n",
    "    for idx in torch.topk(outputs, k=3).indices.tolist():\n",
    "        prob = torch.softmax(outputs, -1)[idx].item()\n",
    "        print('[{idx}] {label:<75} ({p:.2f}%)'.format(idx=idx, label=labels_map[idx], p=prob*100))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<All keys matched successfully>"
      ]
     },
     "execution_count": 27,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Load Jax weights\n",
    "npz = np.load(filename)\n",
    "\n",
    "# Load PyTorch model\n",
    "model = pytorch_pretrained_vit.ViT(name=name, pretrained=False, load_repr_layer=True)\n",
    "\n",
    "# Convert weights\n",
    "new_state_dict = convert(npz, model.state_dict())\n",
    "\n",
    "# Load into model and test\n",
    "model.load_state_dict(new_state_dict)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[3690] yellow_mountain_saxifrage                                                   (0.07%)\n",
      "[228] red_fox                                                                     (0.03%)\n",
      "[7705] amberjack                                                                   (0.02%)\n"
     ]
    }
   ],
   "source": [
    "check_model(model, name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Parameter containing:\n",
       "tensor([[-0.1088,  0.0409,  0.0360,  ..., -0.0716, -0.0018,  0.0037],\n",
       "        [-0.0945,  0.0493, -0.1079,  ..., -0.0731, -0.0225, -0.1013],\n",
       "        [ 0.0356, -0.0351,  0.0510,  ...,  0.0867, -0.0274, -0.0638],\n",
       "        ...,\n",
       "        [ 0.0978, -0.1415,  0.0287,  ...,  0.0058, -0.0644,  0.0968],\n",
       "        [ 0.0014,  0.0605, -0.0371,  ..., -0.1093, -0.1687, -0.0141],\n",
       "        [ 0.0515, -0.1962,  0.0966,  ..., -0.0401, -0.1180, -0.0609]],\n",
       "       requires_grad=True)"
      ]
     },
     "execution_count": 30,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.pre_logits.weight"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[-0.10884866, -0.09448911,  0.03560322, ...,  0.09780853,\n",
       "         0.00137385,  0.05146733],\n",
       "       [ 0.04090862,  0.04934628, -0.03513117, ..., -0.14147781,\n",
       "         0.06050842, -0.19619823],\n",
       "       [ 0.03600169, -0.10787308,  0.05104499, ...,  0.02869395,\n",
       "        -0.03711622,  0.09662006],\n",
       "       ...,\n",
       "       [-0.071619  , -0.07306355,  0.08665656, ...,  0.00580501,\n",
       "        -0.10929134, -0.04009511],\n",
       "       [-0.00184033, -0.02247062, -0.02740383, ..., -0.06440727,\n",
       "        -0.16867375, -0.11800908],\n",
       "       [ 0.00365812, -0.10132378, -0.06381997, ...,  0.09679291,\n",
       "        -0.01409991, -0.06087695]], dtype=float32)"
      ]
     },
     "execution_count": 31,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "npz['pre_logits/kernel']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
